package util

import (
	"gitee.com/kklt1996/data-structure/common"
	"math/rand"
)

/*
	常用的排序函数
*/
type SortFunc func(slice []interface{}, reverse bool, comparator common.CompareFunc)

/*
	默认的排序方法
*/
func Sort(sortAble common.SortAble, reverse bool, comparator common.CompareFunc) {
	SortWithFunc(sortAble, reverse, comparator, MergeSort)
}

/*
	可以选择排序方法的排序方法
*/
func SortWithFunc(sortAble common.SortAble, reverse bool, comparator common.CompareFunc, sortFunc SortFunc) {
	if sortAble == nil {
		return
	}
	slice := sortAble.ToSlice()
	sortFunc(slice, reverse, comparator)
	i := 0
	sortAble.Iterator(func(iterator common.Iterator) bool {
		_ = iterator.Set(slice[i])
		i++
		return false
	})
}

/*
	O(log(n^2))
	选择排序: 遍历数组,不断寻找索引以及索引以后的元素中最小元素和当前索引位置元素进行交换.
	comparator　如果集合中的元素实现了common.CompareAble接口，或者是int float64 string　类型,comparator可以为空
				否则必须传入非nil值
	reverse 是否结果取反
*/
func SelectionSort(slice []interface{}, reverse bool, comparator common.CompareFunc) {
	for i, _ := range slice {
		minIndex := i
		// 寻找i以及后面元素中最小的索引
		for j := i + 1; j < len(slice); j++ {
			if compare(slice[j], slice[minIndex], reverse, comparator) < 0 {
				minIndex = j
			}
		}
		// 最小的元素和i位置元素进行交换
		swap(slice, i, minIndex)
	}
}

/*
	O(log(n^2))
	插入排序: 从index=1开始遍历数组,不断将索引位置元素和索引之前的元素进行比较,找到合适的插入位置
	如果数据是有序的,那么插入排序效率就是O(n)
*/
func InsertSort(slice []interface{}, reverse bool, comparator common.CompareFunc) {
	insertSort(slice, reverse, comparator, 0, len(slice)-1)
}

/*
	对数组的区间[l,r]进行插入排序
*/
func insertSort(slice []interface{}, reverse bool, comparator common.CompareFunc, l int, r int) {
	for i := l + 1; i <= r; i++ {
		// 寻找合适的插入位置
		e := slice[i]
		j := i - 1
		for ; j >= l && compare(e, slice[j], reverse, comparator) < 0; j-- {
			// 如果j位置元素比i位置元素大,那么就需要将j位置元素向后移动一个位置
			slice[j+1] = slice[j]
		}
		// 将第i个元素放置在第一个比第i个元素大的元素的后面
		slice[j+1] = e
	}
}

/*
	O(n*log(n))
	归并排序
*/
func MergeSort(slice []interface{}, reverse bool, comparator common.CompareFunc) {
	mergeSort(slice, reverse, comparator, 0, len(slice)-1)
}

/*
	对[l,r]区间内的元素进行排序
*/
func mergeSort(slice []interface{}, reverse bool, comparator common.CompareFunc, l int, r int) {
	if r-l <= 15 {
		// 递归的终止条件
		insertSort(slice, reverse, comparator, l, r)
		return
	}

	mid := l + (r-l)/2
	// 对拆分的两部分进行排序
	mergeSort(slice, reverse, comparator, l, mid)
	mergeSort(slice, reverse, comparator, mid+1, r)

	// 如果合并之前是无序的,对排序后的结果进行合并
	if compare(slice[mid], slice[mid+1], reverse, comparator) > 0 {
		merge(slice, l, mid, r, reverse, comparator)
	}
}

/*
	O(n*log(n))
	自底向上的归并排序
*/
func MergeSortBU(slice []interface{}, reverse bool, comparator common.CompareFunc) {
	mergeSortBU(slice, reverse, comparator, len(slice))
}

func mergeSortBU(slice []interface{}, reverse bool, comparator common.CompareFunc, n int) {
	for sz := 1; sz <= n; sz = sz << 1 {
		for i := 0; i+sz < n; i += sz << 1 {
			if i+sz<<1-1 > n-1 {
				merge(slice, i, i+sz-1, n-1, reverse, comparator)
			} else {
				merge(slice, i, i+sz-1, i+sz<<1-1, reverse, comparator)
			}
		}
	}
}

/*
	合并两部分排序好的数组
*/
func merge(slice []interface{}, l int, mid int, r int, reverse bool, comparator common.CompareFunc) {
	// 开辟额外的空间复制要排序的元素
	auxCapacity := r - l + 1
	aux := make([]interface{}, auxCapacity, auxCapacity)
	for i := l; i <= r; i++ {
		aux[i-l] = slice[i]
	}

	// 对左右两部分排序好的数组进行合并
	i := l
	j := mid + 1
	for k := l; k <= r; k++ {
		if i > mid {
			// 左边索引超出范围
			slice[k] = aux[j-l]
			j++
		} else if j > r {
			// 右边的索引超出范围
			slice[k] = aux[i-l]
			i++
		} else if compare(aux[i-l], aux[j-l], reverse, comparator) < 0 {
			// 左边的元素小于右边的元素
			slice[k] = aux[i-l]
			i++
		} else {
			slice[k] = aux[j-l]
			j++
		}
	}
}

/*
	O(n*log(n))
	快速排序
*/
func QuickSort(slice []interface{}, reverse bool, comparator common.CompareFunc) {
	quickSort(slice, reverse, comparator, 0, len(slice)-1)
}

/*
	对[l,r]区间进行快速排序
*/
func quickSort(slice []interface{}, reverse bool, comparator common.CompareFunc, l int, r int) {
	if r-l <= 15 {
		// 小区间使用插入排序提高速度
		insertSort(slice, reverse, comparator, l, r)
		return
	}

	// 进行partition操作将l找到合适的位置
	lt, gt := partition(slice, reverse, comparator, l, r)

	quickSort(slice, reverse, comparator, l, lt)
	quickSort(slice, reverse, comparator, gt, r)
}

/*
	三路快排处理 arr[l,r]
	假定v是标定元素,将arr[l,r]分为 < v; ==v; >v 三部分,之后对 <v 和 >v 的部分进行三路快排序
*/
func partition(slice []interface{}, reverse bool, comparator common.CompareFunc, l int, r int) (int, int) {
	// 随机选择标定元素
	swap(slice, rand.Int()%(r-l+1)+l, l)
	// 小于v的最大索引
	lt := l
	// 大于v的最小索引
	gt := r + 1
	// 遍历数组[l+1,r],不断的将比v小的元素放到lt索引的左边,将比v大的元素放在gt的右边
	i := l + 1
	for i < gt {
		compareResult := compare(slice[i], slice[l], reverse, comparator)
		if compareResult < 0 {
			// 比v小的区间向右延伸
			swap(slice, lt+1, i)
			lt++
			i++
		} else if compareResult > 0 {
			// 比v大的区间向左延伸
			swap(slice, gt-1, i)
			gt--
			//	因为gt-1位置元素没有和l位置元素被比较过,所以不需要i++
		} else {
			//	==v区间向右延伸
			i++
		}
	}
	// 将l位置的元素和最后一个比l小的元素进行交换
	swap(slice, l, lt)
	lt--
	return lt, gt
}

/*
	O(n*log(n))
	堆排序,先将数组通过heapify的方式整理成堆,然后不断的取出堆顶部的元素,和数组尾部未排序的第一个元素进行交换
	因为不需要额外的空间,所以对于随机数组比归并排序要快.
*/
func HeapSort(slice []interface{}, reverse bool, compareFunc common.CompareFunc) {
	// 从最后一个叶子节点的父亲节点开始,执行siftDown操作,将数组调整成堆
	for i := getHeapParent(len(slice) - 1); i > -1; i-- {
		siftDown(slice, i, reverse, compareFunc, len(slice)-1)
	}
	// 对整理好的堆进行排序
	for i := len(slice) - 1; i > 0; i-- {
		// 将最最大堆中最大的元素和i位置元素进行交换
		swap(slice, 0, i)
		// 将i-1之前的数组堆进行siftDown操作,整理成符合最大堆的性质
		siftDown(slice, 0, reverse, compareFunc, i-1)
	}
}

func siftDown(slice []interface{}, index int, reverse bool, compareFunc common.CompareFunc, lastIndex int) {
	value := slice[index]
	for {
		// 获取左右孩子节点中最大的元素
		maxValueIndex := getHeapLeftChild(index)
		if maxValueIndex > lastIndex {
			break
		}
		maxValue := slice[maxValueIndex]
		rightChildIndex := getHeapRightChild(index)
		if rightChildIndex <= lastIndex {
			rightChildValue := slice[rightChildIndex]
			if compare(maxValue, rightChildValue, reverse, compareFunc) < 0 {
				maxValue = rightChildValue
				maxValueIndex++
			}
		}
		if compare(value, maxValue, reverse, compareFunc) < 0 {
			// 左右孩子节点中最大的节点比当前要siftDown的节点大,需要进行交换
			swap(slice, index, maxValueIndex)
			index = maxValueIndex
		} else {
			// 左右节点都比孩子节点小,满足最大堆的性质
			break
		}
	}
}

func getHeapParent(index int) int {
	return (index - 1) / 2
}

func getHeapLeftChild(index int) int {
	return 2*index + 1
}

func getHeapRightChild(index int) int {
	return 2*index + 2
}

func DefaultComparator(thisValue interface{}, compareValue interface{}) int {
	switch thisValue.(type) {
	case int:
		i := thisValue.(int)
		j := compareValue.(int)
		if i == j {
			return 0
		} else if i > j {
			return 1
		} else {
			return -1
		}
	case float64:
		i := thisValue.(float64)
		j := compareValue.(float64)
		if i == j {
			return 0
		} else if i > j {
			return 1
		} else {
			return -1
		}
	case string:
		i := thisValue.(float64)
		j := compareValue.(float64)
		if i == j {
			return 0
		} else if i > j {
			return 1
		} else {
			return -1
		}
	}
	return thisValue.(common.CompareAble).CompareTo(compareValue.(common.CompareAble))
}

/*
	对两个值进行比较
*/
func compare(thisValue interface{}, compareValue interface{}, reverse bool, comparator common.CompareFunc) int {
	var ret int
	if comparator != nil {
		ret = comparator(thisValue, compareValue)
	} else {
		ret = DefaultComparator(thisValue, compareValue)
	}
	if reverse {
		ret = -ret
	}
	return ret
}

/*
	交换切片中两个元素的位置
*/
func swap(slice []interface{}, i int, minIndex int) {
	slice[i], slice[minIndex] = slice[minIndex], slice[i]
}
